import numpy as np
import torch
from torch import Tensor

from src.envs.base_environment import ContinuousEnvironment  



class GridEnvironment(ContinuousEnvironment):
    r"""
    ### Description

    The grid environment is a 2D environment where the position is a 2D vector in the x-y plane.
    The action is a 2D vector representing the change in the (x, y) coordinates.
    The goal is to sample a grid reward distribution with high reward peaks at
    (-edge_size, -edge_size), (-edge_size, edge_size), (edge_size, edge_size) and
    (edge_size, -edge_size), e.g. (±10, ±10) for `edge_size = 10`.

    `step` and `backward_step` operate on state tensors whose columns 0 and 1 hold the
    (x, y) coordinates and whose column 2 holds a step counter.

    ### Action Space

    | Num | Action    | Min        | Max       |
    |-----|-----------|------------|-----------|
    | 0   | Delta x   | -edge_size | edge_size |
    | 1   | Delta y   | -edge_size | edge_size |

    The Min/Max columns bound the means of the policy's mixture components (see
    `postprocess_params`); sampled actions are not clipped in this class.

    ### Observation Space

    | Num | Observation | Min | Max |
    |-----|-------------|-----|-----|
    | 0   | x           | -15 | 15  |
    | 1   | y           | -15 | 15  |

    ### Rewards

    The reward density is the sum of four standard bivariate normal densities:

        r(x, y) = \sum_{i=1}^{4} \frac{1}{2\pi} \exp\left(-\frac{(x - x_i)^2 + (y - y_i)^2}{2}\right)

    where (x_i, y_i) are the peak coordinates listed above. `log_reward` returns log r(x, y).

    ### Policy Parameterisation

    The policy is parameterised as a mixture of `mixture_dim` bivariate Gaussians with
    diagonal covariance. The network outputs `5 * mixture_dim` values per state:
    the x-means, y-means, x-standard-deviations, y-standard-deviations and mixture-weight
    logits, each block `mixture_dim` wide.

    ### Arguments (read from `config["env"]`)

    - `max_policy_std`: Maximum standard deviation of each Gaussian component.
    - `min_policy_std`: Minimum standard deviation of each Gaussian component.
    - `edge_size`: Absolute coordinate of the reward peaks; also bounds the component means.
    - `mixture_dim`: Number of components in the Gaussian mixture used by the policy.
    """

    def __init__(
            self,
            config):
        """Build the reward mixture and initialise the base continuous environment.

        Args:
            config: Configuration dict; reads `config["device"]` and the keys listed
                under "Arguments" from `config["env"]`.
        """
        self._init_required_params(config)
        device = config["device"]
        # Float bounds: this is a continuous environment, and integer tensors would not
        # mix cleanly with floating-point sampling/arithmetic on states.
        lower_bound = torch.tensor([-15.0, -15.0], device=device)
        upper_bound = torch.tensor([15.0, 15.0], device=device)
        edge = float(self.edge_size)
        means = torch.tensor(
            [[-edge, -edge], [-edge, edge], [edge, edge], [edge, -edge]],
            device=device,
        )
        # One unit-covariance Gaussian per reward peak; `log_reward` sums their densities.
        identity = torch.eye(2, device=device)
        self.mixture = [
            torch.distributions.MultivariateNormal(mean, identity) for mean in means
        ]
        super().__init__(config,
                         dim = 2,
                         feature_dim = 2,
                         angle_dim = [False, False],
                         action_dim = 2,
                         lower_bound = lower_bound,
                         upper_bound = upper_bound,
                         mixture_dim = config["env"]["mixture_dim"],
                         output_dim = 5 * config["env"]["mixture_dim"])

    def _init_required_params(self, config):
        """Validate and store the environment hyper-parameters from `config["env"]`.

        Raises:
            KeyError: if any required key is missing from `config["env"]`.
        """
        # Every key read from config["env"] by this class is validated up front.
        required_params = ["max_policy_std", "min_policy_std", "edge_size", "mixture_dim"]
        env_config = config["env"]
        missing = [param for param in required_params if param not in env_config]
        if missing:
            raise KeyError(f"Missing required parameters: {missing}")
        self.max_policy_std = env_config["max_policy_std"]
        self.min_policy_std = env_config["min_policy_std"]
        self.edge_size = env_config["edge_size"]

    def log_reward(self, x: Tensor) -> Tensor:
        """Return log r(x): the log of the summed densities of the four peak Gaussians.

        `x` is passed directly to each component's `log_prob`, so its last dimension
        must be 2 (the (x, y) position).
        """
        # Stack to (4, ...) and reduce over the component axis in log-space for stability.
        component_log_probs = torch.stack([m.log_prob(x) for m in self.mixture], 0)
        return torch.logsumexp(component_log_probs, 0)

    def step(self, x: Tensor, action: Tensor) -> Tensor:
        """Takes a step in the environment given an action. x is the current state and action is the action to take. Returns the new state."""
        # x: [batch_size, >=3] -- columns 0, 1 are (x, y), column 2 is the step counter
        # action: [batch_size, >=2] -- columns 0, 1 are (delta x, delta y)
        new_x = torch.zeros_like(x)
        new_x[:, 0] = x[:, 0] + action[:, 0]  # Update x coordinate
        new_x[:, 1] = x[:, 1] + action[:, 1]  # Update y coordinate
        new_x[:, 2] = x[:, 2] + 1             # Increment step counter

        return new_x

    def backward_step(self, x: Tensor, action: Tensor) -> Tensor:
        """Takes a backward step in the environment given an action. x is the current state and action that had been taken to reach x. Returns the previous state."""
        # x: [batch_size, >=3] -- columns 0, 1 are (x, y), column 2 is the step counter
        # action: [batch_size, >=2] -- columns 0, 1 are (delta x, delta y)
        new_x = torch.zeros_like(x)
        new_x[:, 0] = x[:, 0] - action[:, 0]  # Undo x update
        new_x[:, 1] = x[:, 1] - action[:, 1]  # Undo y update
        new_x[:, 2] = x[:, 2] - 1             # Decrement step counter

        return new_x

    def compute_initial_action(self, first_state):
        """Return the action that moves the initial state to `first_state`.

        NOTE(review): `self.init_value` is not set in this class; presumably it is the
        initial state provided by the base class -- confirm.
        """
        return first_state - self.init_value

    def _init_policy_dist(self, param_dict):
        """Build a mixture of bivariate Gaussians with diagonal covariance for the policy.

        Args:
            param_dict: Output of `postprocess_params`, with keys "mus_x", "mus_y",
                "sigmas_x", "sigmas_y" (standard deviations) and "weights" (probabilities).

        Returns:
            A `torch.distributions.MixtureSameFamily` over 2D actions.
        """
        mus_x, mus_y = param_dict["mus_x"], param_dict["mus_y"]
        sigmas_x, sigmas_y = param_dict["sigmas_x"], param_dict["sigmas_y"]
        weights = param_dict["weights"]

        mus = torch.stack([mus_x, mus_y], dim=-1)

        # The sigmas are standard deviations, so diag(sigma) is the Cholesky factor of
        # the covariance diag(sigma ** 2). Passing them as `covariance_matrix` would
        # make the effective standard deviation sqrt(sigma).
        sigmas = torch.stack([sigmas_x, sigmas_y], dim=-1)
        scale_tril = torch.diag_embed(sigmas)

        mix = torch.distributions.Categorical(probs=weights)
        components = torch.distributions.MultivariateNormal(mus, scale_tril=scale_tril)

        return torch.distributions.MixtureSameFamily(mix, components)

    def postprocess_params(self, params):
        """Map raw network outputs to valid mixture parameters.

        `params` is split column-wise into blocks of width `mixture_dim`:
        mu_x, mu_y, sigma_x, sigma_y, then the remaining columns as weight logits.

        Returns:
            Dict with "mus_x", "mus_y" in (-edge_size, edge_size), "sigmas_x",
            "sigmas_y" in (min_policy_std, max_policy_std) and "weights" (softmax over dim 1).
        """
        m = self.mixture_dim
        mu_x_params = params[:, :m]
        mu_y_params = params[:, m:2 * m]
        sigma_x_params = params[:, 2 * m:3 * m]
        sigma_y_params = params[:, 3 * m:4 * m]
        weight_params = params[:, 4 * m:]

        # Squash the component means into (-edge_size, edge_size).
        mus_x = torch.sigmoid(mu_x_params) * (2 * self.edge_size) - self.edge_size
        mus_y = torch.sigmoid(mu_y_params) * (2 * self.edge_size) - self.edge_size

        # Squash the standard deviations into (min_policy_std, max_policy_std).
        std_range = self.max_policy_std - self.min_policy_std
        sigmas_x = torch.sigmoid(sigma_x_params) * std_range + self.min_policy_std
        sigmas_y = torch.sigmoid(sigma_y_params) * std_range + self.min_policy_std

        weights = torch.softmax(weight_params, dim=1)

        return {
            "mus_x": mus_x,
            "mus_y": mus_y,
            "sigmas_x": sigmas_x,
            "sigmas_y": sigmas_y,
            "weights": weights,
        }

    def add_noise(self, param_dict: dict, off_policy_noise: float):
        """Widen the policy's standard deviations by `off_policy_noise` for exploration.

        The dict is updated and also returned. The sigma tensors are replaced rather
        than modified in place.
        """
        # Out-of-place addition: rebinding the entries keeps the caller-visible dict
        # update, without mutating tensors that may be aliased elsewhere or saved by
        # autograd for a backward pass.
        param_dict["sigmas_x"] = param_dict["sigmas_x"] + off_policy_noise
        param_dict["sigmas_y"] = param_dict["sigmas_y"] + off_policy_noise

        return param_dict
    